package com.ideaaedi.mybatis.data.security.interceptor;

import com.ideaaedi.mybatis.data.security.support.EncryptInfoHolder;
import com.ideaaedi.mybatis.data.security.support.EncryptParser;
import com.ideaaedi.mybatis.data.security.support.PojoCloneable;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.CachingExecutor;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/**
 * Custom MyBatis plugin that encrypts sensitive fields on the way into the database and decrypts them on the way
 * out.
 * <p>
 * Intercepts {@code Executor#update}, both {@code Executor#query} overloads and {@code Executor#close(boolean)}.
 * Encryption replaces the statement parameter with the value returned by {@link EncryptParser#doEncrypt}, which also
 * fills an "origin pojo -&gt; clone pojo list" map. Data written back into the clones by the database round-trip
 * (e.g. an auto-increment id) is later copied into the original pojos.
 * <p>
 * Timing of that copy depends on the executor. When the target is not a {@link CachingExecutor}, it happens right
 * after the statement. Otherwise it is deferred to {@code close} through a thread-local map.
 *
 * @author JustryDeng
 * @since 2021/2/10 22:40:59
 */
@Intercepts(value = {
        // inbound (writes)
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        // outbound (reads)
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
        // other: flushes back-fill data that was deferred for a CachingExecutor
        @Signature(type = Executor.class, method = "close", args = {boolean.class})
})
public class MybatisEncryptPlugin implements Interceptor {

    private static final Logger log = LoggerFactory.getLogger(MybatisEncryptPlugin.class);

    /** Statement-id suffix of the MyBatis-Plus pagination count query. */
    private static final String MP_COUNT_SUFFIX = "_mpCount";

    /**
     * Back-fill data deferred until {@code Executor#close(boolean)} when the intercepted target is a
     * {@link CachingExecutor}. Maps each original parameter pojo to its clone list.
     */
    @SuppressWarnings("rawtypes")
    private static final ThreadLocal<Map<PojoCloneable, List<PojoCloneable>>> originPojoAndClonePojoListMapThreadLocal = new ThreadLocal<>();

    private final EncryptParser encryptParser;

    public MybatisEncryptPlugin(EncryptParser encryptParser) {
        this.encryptParser = encryptParser;
    }

    /**
     * Dispatches the intercepted {@link Executor} call to the matching handler.
     *
     * @param invocation the intercepted invocation
     * @return the (possibly decrypted) result of the underlying executor call
     * @throws UnsupportedOperationException if the target is not an {@link Executor}
     * @throws Throwable anything thrown by the underlying call or by encryption/decryption
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object target = invocation.getTarget();
        if (!(target instanceof Executor)) {
            throw new UnsupportedOperationException("Base 'MybatisEncryptPlugin.plugin' Setting, curr plugin only "
                    + "support Executor.class");
        }
        Method method = invocation.getMethod();
        String name = method.getName();
        if ("update".equals(name) || "query".equals(name)) {
            return interceptUpdateOrQuery(invocation, target);
        }
        if ("close".equals(name)) {
            return interceptClose(invocation);
        }
        return invocation.proceed();
    }

    /**
     * Handles {@code update}/{@code query}: encrypts the parameter if required, runs the statement, handles
     * back-fill data, then decrypts the result if required.
     */
    private Object interceptUpdateOrQuery(Invocation invocation, Object target) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) args[0];
        String mappedStatementId = mappedStatement.getId();
        EncryptInfoHolder encryptInfoHolder = resolveEncryptInfoHolder(mappedStatement);
        if (encryptInfoHolder == null) {
            log.debug("mappedStatementId [{}] match non-encryptInfoHolder. Operate database directly.",
                    mappedStatementId);
            return invocation.proceed();
        }
        //noinspection rawtypes
        Map<PojoCloneable, List<PojoCloneable>> originPojoAndClonePojoListMap = new HashMap<>(32);
        if (encryptInfoHolder.isNeedEncrypt()) {
            log.debug("mappedStatementId [{}] need encrypt.", mappedStatementId);
            // replace the statement parameter in place; doEncrypt records origin -> clone mappings
            args[1] = encryptParser.doEncrypt(mappedStatement, args[1], encryptInfoHolder,
                    originPojoAndClonePojoListMap);
        }
        Object rowResult = invocation.proceed();
        handleBackFillData(target, originPojoAndClonePojoListMap);
        if (encryptInfoHolder.isNeedDecrypt()) {
            log.debug("mappedStatementId [{}] need decrypt.", mappedStatementId);
            rowResult = encryptParser.doDecrypt(mappedStatement, rowResult, encryptInfoHolder);
        }
        return rowResult;
    }

    /**
     * Looks up the encrypt info for a statement.
     * <p>
     * If no info is found and the id ends with {@link #MP_COUNT_SUFFIX} (the MyBatis-Plus pagination count query),
     * this falls back to the info of the raw statement id. When that fallback succeeds, the count-statement id is
     * registered in {@link EncryptParser.EncryptOop}.
     *
     * @return the matching holder, or {@code null} if the statement needs no encryption handling
     */
    private EncryptInfoHolder resolveEncryptInfoHolder(MappedStatement mappedStatement) {
        EncryptInfoHolder encryptInfoHolder = encryptParser.determineEncryptInfoHolder(mappedStatement);
        String mappedStatementId = mappedStatement.getId();
        if (encryptInfoHolder != null || !mappedStatementId.endsWith(MP_COUNT_SUFFIX)) {
            return encryptInfoHolder;
        }
        String rawMappedStatementId = mappedStatementId.substring(0,
                mappedStatementId.length() - MP_COUNT_SUFFIX.length());
        encryptInfoHolder = encryptParser.determineEncryptInfoHolder(rawMappedStatementId);
        if (encryptInfoHolder != null) {
            EncryptParser.EncryptOop encryptOop = encryptParser.getEncryptOop();
            if (!encryptOop.existMpCount(rawMappedStatementId)) {
                encryptOop.putMpCount(rawMappedStatementId, mappedStatementId);
            }
        }
        return encryptInfoHolder;
    }

    /**
     * Handles {@code close}: closes the executor, then applies any back-fill data deferred on this thread.
     * <p>
     * The thread-local is always cleared, even if closing or back-filling throws. Otherwise stale data could leak
     * into a later session that reuses this (typically pooled) thread.
     */
    private Object interceptClose(Invocation invocation) throws Throwable {
        try {
            Object proceed = invocation.proceed();
            //noinspection rawtypes
            Map<PojoCloneable, List<PojoCloneable>> deferredPojoCloneableListMap = originPojoAndClonePojoListMapThreadLocal.get();
            if (deferredPojoCloneableListMap != null) {
                doBackFill(deferredPojoCloneableListMap);
            }
            return proceed;
        } finally {
            originPojoAndClonePojoListMapThreadLocal.remove();
        }
    }

    @Override
    public Object plugin(Object target) {
        // @Intercepts above only declares Executor signatures, so only Executor targets need to be wrapped
        if (target instanceof Executor) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    @Override
    public void setProperties(Properties properties) {
        // ignore
    }

    /**
     * Handles back-fill data produced by one statement.
     * <p>
     * For a non-{@link CachingExecutor} target, the data is back-filled immediately and the given map is cleared.
     * For a {@link CachingExecutor} target, non-empty data is merged into the thread-local map and back-filled in
     * {@code close}.
     *
     * @param target                        the intercepted executor
     * @param originPojoAndClonePojoListMap origin pojo -&gt; clone pojo list mapping for the current statement
     */
    private void handleBackFillData(Object target, @SuppressWarnings("rawtypes") Map<PojoCloneable, List<PojoCloneable>> originPojoAndClonePojoListMap) {
        if (!(target instanceof CachingExecutor)) {
            doBackFill(originPojoAndClonePojoListMap);
            originPojoAndClonePojoListMap.clear();
            return;
        }
        if (originPojoAndClonePojoListMap.isEmpty()) {
            return;
        }
        //noinspection rawtypes
        Map<PojoCloneable, List<PojoCloneable>> threadLocalPojoCloneableListMap = originPojoAndClonePojoListMapThreadLocal.get();
        if (threadLocalPojoCloneableListMap == null) {
            threadLocalPojoCloneableListMap = new HashMap<>(32);
            originPojoAndClonePojoListMapThreadLocal.set(threadLocalPojoCloneableListMap);
        }
        threadLocalPojoCloneableListMap.putAll(originPojoAndClonePojoListMap);
    }

    /**
     * Back-fills data from clone pojos into their original pojos.
     * <p>
     * Field data written into a clone pojo (e.g. an auto-increment id) is copied into the matching field of the
     * original pojo via {@code PojoCloneable#handleNullPropertyByClonePojo}. Entries with a {@code null} key or value
     * are skipped.
     * <p>
     * bugfix: <a href="https://gitee.com/JustryDeng/mybatis-data-security/issues/I7V24O">issue-I7V24O</a>
     *
     * @param originPojoAndClonePojoListMap origin pojo -&gt; clone pojo list mapping; may be {@code null} or empty
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private void doBackFill(Map<PojoCloneable, List<PojoCloneable>> originPojoAndClonePojoListMap) {
        if (CollectionUtils.isEmpty(originPojoAndClonePojoListMap)) {
            return;
        }
        originPojoAndClonePojoListMap.forEach((originPojo, clonePojoList) -> {
            if (originPojo == null || clonePojoList == null) {
                return;
            }
            originPojo.handleNullPropertyByClonePojo(clonePojoList);
        });
    }

}
